{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Custom-Networks.ipynb",
      "provenance": [],
      "collapsed_sections": [
        "jZgxBq4s3ndn",
        "0klE2QptdcaG",
        "an_z0dDmflp1",
        "lnrHPajlwHHM",
        "G31YQExqwNNc",
        "r9WgeoI3wo-o"
      ],
      "toc_visible": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU",
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "1b6f39faaa8544c49e9b84ecfd67ca41": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "state": {
            "_view_name": "HBoxView",
            "_dom_classes": [],
            "_model_name": "HBoxModel",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "box_style": "",
            "layout": "IPY_MODEL_7ddf2ce051e14b249a7b8dfdcd2b08c1",
            "_model_module": "@jupyter-widgets/controls",
            "children": [
              "IPY_MODEL_7e948458c6f348adb9b5dcc19594896d",
              "IPY_MODEL_692f5b26181b402690cd682f4ef389ee"
            ]
          }
        },
        "7ddf2ce051e14b249a7b8dfdcd2b08c1": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "7e948458c6f348adb9b5dcc19594896d": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "state": {
            "_view_name": "ProgressView",
            "style": "IPY_MODEL_012265f6cb7e4516a8a6cf5318a30ce2",
            "_dom_classes": [],
            "description": "",
            "_model_name": "FloatProgressModel",
            "bar_style": "info",
            "max": 1,
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": 1,
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "orientation": "horizontal",
            "min": 0,
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_e0aff2702cdd47f795f2d2a233ba8248"
          }
        },
        "692f5b26181b402690cd682f4ef389ee": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_90782a422602433fb529a31828ea711a",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": " 9920512/? [00:20&lt;00:00, 358705.03it/s]",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_769e755461c64b16a37b279cbf02cb5b"
          }
        },
        "012265f6cb7e4516a8a6cf5318a30ce2": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "ProgressStyleModel",
            "description_width": "initial",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "bar_color": null,
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "e0aff2702cdd47f795f2d2a233ba8248": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "90782a422602433fb529a31828ea711a": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "769e755461c64b16a37b279cbf02cb5b": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "dfe3ba30ba8846ad9a533089bfee6987": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "state": {
            "_view_name": "HBoxView",
            "_dom_classes": [],
            "_model_name": "HBoxModel",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "box_style": "",
            "layout": "IPY_MODEL_b8462eea1e2b49df87b160453da3fb8e",
            "_model_module": "@jupyter-widgets/controls",
            "children": [
              "IPY_MODEL_07918b13b9e64c219ca8e2f07c1193be",
              "IPY_MODEL_1eb02d1769b0438582cae93bf765327e"
            ]
          }
        },
        "b8462eea1e2b49df87b160453da3fb8e": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "07918b13b9e64c219ca8e2f07c1193be": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "state": {
            "_view_name": "ProgressView",
            "style": "IPY_MODEL_d521e62b249d470b863a70ac156341f6",
            "_dom_classes": [],
            "description": "",
            "_model_name": "FloatProgressModel",
            "bar_style": "success",
            "max": 1,
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": 1,
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "orientation": "horizontal",
            "min": 0,
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_83ef5b03b73c45b3b5bb8c18eabc72de"
          }
        },
        "1eb02d1769b0438582cae93bf765327e": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_787f5b89f95c4255aef60fa6f91d3b22",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": " 32768/? [00:00&lt;00:00, 99630.91it/s]",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_07d4d1e25e2a41ec97196a1a289199c5"
          }
        },
        "d521e62b249d470b863a70ac156341f6": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "ProgressStyleModel",
            "description_width": "initial",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "bar_color": null,
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "83ef5b03b73c45b3b5bb8c18eabc72de": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "787f5b89f95c4255aef60fa6f91d3b22": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "07d4d1e25e2a41ec97196a1a289199c5": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "2d599bc13c8543da9ec1c82b11ab6a69": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "state": {
            "_view_name": "HBoxView",
            "_dom_classes": [],
            "_model_name": "HBoxModel",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "box_style": "",
            "layout": "IPY_MODEL_464e56783ced4a039c4f5fcc1fd290e1",
            "_model_module": "@jupyter-widgets/controls",
            "children": [
              "IPY_MODEL_678e744efdf9417f8d55151b76bbece1",
              "IPY_MODEL_263e5f6844bd4f7c81876d4093755c63"
            ]
          }
        },
        "464e56783ced4a039c4f5fcc1fd290e1": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "678e744efdf9417f8d55151b76bbece1": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "state": {
            "_view_name": "ProgressView",
            "style": "IPY_MODEL_0cc06bcc77df4921b6b8d9007b42ac05",
            "_dom_classes": [],
            "description": "",
            "_model_name": "FloatProgressModel",
            "bar_style": "info",
            "max": 1,
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": 1,
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "orientation": "horizontal",
            "min": 0,
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_38a4bb4a15e3405a8007e9691f69d825"
          }
        },
        "263e5f6844bd4f7c81876d4093755c63": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_f7fb8781ff2646a9bd86440877ecc101",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": " 1654784/? [00:18&lt;00:00, 616186.61it/s]",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_6998ba54432c4eb38bab6792887c7ccc"
          }
        },
        "0cc06bcc77df4921b6b8d9007b42ac05": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "ProgressStyleModel",
            "description_width": "initial",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "bar_color": null,
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "38a4bb4a15e3405a8007e9691f69d825": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "f7fb8781ff2646a9bd86440877ecc101": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "6998ba54432c4eb38bab6792887c7ccc": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "035b97b7a7e74191ac92016c697ba375": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "state": {
            "_view_name": "HBoxView",
            "_dom_classes": [],
            "_model_name": "HBoxModel",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "box_style": "",
            "layout": "IPY_MODEL_726c4cf5f5b64437a6327c48236aeb42",
            "_model_module": "@jupyter-widgets/controls",
            "children": [
              "IPY_MODEL_18fa0b70765842b2abf83c1b4793816c",
              "IPY_MODEL_ade571a07230429d9f5ee996b9fd8760"
            ]
          }
        },
        "726c4cf5f5b64437a6327c48236aeb42": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "18fa0b70765842b2abf83c1b4793816c": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "state": {
            "_view_name": "ProgressView",
            "style": "IPY_MODEL_82a261efb72447feae936e85cd4bbc72",
            "_dom_classes": [],
            "description": "",
            "_model_name": "FloatProgressModel",
            "bar_style": "success",
            "max": 1,
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": 1,
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "orientation": "horizontal",
            "min": 0,
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_75ecd0780a0b435eb91ddb9c11dbf0d5"
          }
        },
        "ade571a07230429d9f5ee996b9fd8760": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_db052fa05c2d4224a87e2df722895b24",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": " 8192/? [00:00&lt;00:00, 15616.38it/s]",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_8c682e975d654977bc80c939cd4a40d8"
          }
        },
        "82a261efb72447feae936e85cd4bbc72": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "ProgressStyleModel",
            "description_width": "initial",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "bar_color": null,
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "75ecd0780a0b435eb91ddb9c11dbf0d5": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "db052fa05c2d4224a87e2df722895b24": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "8c682e975d654977bc80c939cd4a40d8": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        }
      }
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_76JLWRxku5p",
        "colab_type": "text"
      },
      "source": [
        "# Creating Custom Networks for Multi-Class Classification\n",
        "\n",
        "This tutorial demonstrates how to define, train, and use different models for multi-class classification. We will reuse most of the code from the [Logistic Regression](Logistic_Regression.html) tutorial so if you haven't gone through that, consider reviewing it first.\n",
        "\n",
        "Note that this tutorial includes a demonstration of how to build and train a simple convolutional neural network, and running this colab on CPU may take some time. Therefore, we recommend running this colab on GPU (select ``GPU`` on the menu ``Runtime`` -> ``Change runtime type`` -> ``Hardware accelerator`` if the hardware accelerator is not set to GPU)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "by3H5pItdUj8",
        "colab_type": "text"
      },
      "source": [
        "## Import Modules\n",
        "\n",
        "We start by importing the modules we will use in our code."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "H4xxyuS0czl_",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "%pip --quiet install objax\n",
        "\n",
        "import os\n",
        "\n",
        "import numpy as np\n",
        "import tensorflow_datasets as tfds\n",
        "\n",
        "import objax\n",
        "from objax.util import EasyDict\n",
        "from objax.zoo.dnnet import DNNet"
      ],
      "execution_count": 1,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wiMCbnmtdKAE",
        "colab_type": "text"
      },
      "source": [
        "## Load the data\n",
        "\n",
        "Next, we will load the \"[MNIST](http://yann.lecun.com/exdb/mnist/)\" dataset from [TensorFlow DataSets](https://www.tensorflow.org/datasets/api_docs/python/tfds). This dataset contains images of handwritten digits (i.e., numbers between 0 and 9), and the task is to correctly identify the digit in each image.\n",
        "\n",
        "The ``prepare`` method pads 2 pixels to the left, right, top, and bottom of each image to resize it to 32 x 32 pixels. While MNIST images are grayscale, the ``prepare`` method expands each image to three color channels to demonstrate the process of working with color images. The same method also rescales each pixel value to [-1, 1] and converts the image to (N, C, H, W) format."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "YEBT7ZTGdJb_",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Data: train has 60000 images - test has 10000 images\n",
        "# Each image is padded to 32 x 32, expanded to 3 channels, and stored as 3 x 32 x 32 (NCHW)\n",
        "DATA_DIR = os.path.join(os.environ['HOME'], 'TFDS')\n",
        "data = tfds.as_numpy(tfds.load(name='mnist', batch_size=-1, data_dir=DATA_DIR))\n",
        "\n",
        "def prepare(x):\n",
        "  \"\"\"Pads 2 pixels to the left, right, top, and bottom of each image, scales pixel value to [-1, 1], and converts to NCHW format.\"\"\"\n",
        "  s = x.shape\n",
        "  x_pad = np.zeros((s[0], 32, 32, 1))\n",
        "  x_pad[:, 2:-2, 2:-2, :] = x\n",
        "  return objax.util.image.nchw(\n",
        "      np.concatenate([x_pad.astype('f') * (1 / 127.5) - 1] * 3, axis=-1))\n",
        "\n",
        "train = EasyDict(image=prepare(data['train']['image']), label=data['train']['label'])\n",
        "test = EasyDict(image=prepare(data['test']['image']), label=data['test']['label'])\n",
        "ndim = train.image.shape[-1]\n",
        "\n",
        "del data"
      ],
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jZgxBq4s3ndn",
        "colab_type": "text"
      },
      "source": [
        "## Deep Neural Network Model\n",
        "\n",
        "Objax offers many predefined models that we can use for classification. One example is the ``objax.zoo.dnnet.DNNet`` model, which comprises multiple fully connected layers with configurable sizes and activation functions."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "9KeRqkHq3nFg",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "dnn_layer_sizes = 3072, 128, 10\n",
        "dnn_model = DNNet(dnn_layer_sizes, objax.functional.leaky_relu)"
      ],
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0klE2QptdcaG",
        "colab_type": "text"
      },
      "source": [
        "## Custom Model Definition\n",
        "\n",
        "Alternatively, we can define a new model customized to our machine learning task. We demonstrate this process by defining a convolutional network (ConvNet) from scratch. \n",
        "\n",
        "We use ``objax.nn.Sequential`` to compose convolution (``objax.nn.Conv2D``), batch normalization (``objax.nn.BatchNorm2D``), and ReLU (``objax.functional.relu``) layers into blocks, and combine these blocks with Max Pooling (``objax.functional.max_pool_2d``), Average Pooling (``jax.numpy.mean`` over the spatial axes), and Linear (``objax.nn.Linear``) layers.\n",
        "\n",
        "Since the [batch normalization layer](https://arxiv.org/abs/1502.03167) behaves differently at training and at prediction time, we pass the ``training`` flag to the ``__call__`` method of the ``ConvNet`` class. The model itself always outputs logits; probabilities are obtained at prediction time by applying softmax to these logits (see ``train_model`` below)."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "DqvdMVI75fl1",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class ConvNet(objax.Module):\n",
        "  \"\"\"ConvNet implementation.\"\"\"\n",
        "\n",
        "  def __init__(self, nin, nclass):\n",
        "    \"\"\"Define 3 blocks of conv-bn-relu-conv-bn-relu followed by linear layer.\"\"\"\n",
        "    self.conv_block1 = objax.nn.Sequential([objax.nn.Conv2D(nin, 16, 3, use_bias=False),\n",
        "                                            objax.nn.BatchNorm2D(16),\n",
        "                                            objax.functional.relu,\n",
        "                                            objax.nn.Conv2D(16, 16, 3, use_bias=False),\n",
        "                                            objax.nn.BatchNorm2D(16),\n",
        "                                            objax.functional.relu])\n",
        "    self.conv_block2 = objax.nn.Sequential([objax.nn.Conv2D(16, 32, 3, use_bias=False),\n",
        "                                            objax.nn.BatchNorm2D(32),\n",
        "                                            objax.functional.relu,\n",
        "                                            objax.nn.Conv2D(32, 32, 3, use_bias=False),\n",
        "                                            objax.nn.BatchNorm2D(32),\n",
        "                                            objax.functional.relu])\n",
        "    self.conv_block3 = objax.nn.Sequential([objax.nn.Conv2D(32, 64, 3, use_bias=False),\n",
        "                                            objax.nn.BatchNorm2D(64),\n",
        "                                            objax.functional.relu,\n",
        "                                            objax.nn.Conv2D(64, 64, 3, use_bias=False),\n",
        "                                            objax.nn.BatchNorm2D(64),\n",
        "                                            objax.functional.relu])\n",
        "    self.linear = objax.nn.Linear(64, nclass)\n",
        "\n",
        "  def __call__(self, x, training):\n",
        "    x = self.conv_block1(x, training=training)\n",
        "    x = objax.functional.max_pool_2d(x, size=2, strides=2)\n",
        "    x = self.conv_block2(x, training=training)\n",
        "    x = objax.functional.max_pool_2d(x, size=2, strides=2)\n",
        "    x = self.conv_block3(x, training=training)\n",
        "    x = x.mean((2, 3))\n",
        "    x = self.linear(x)\n",
        "    return x\n",
        "    "
      ],
      "execution_count": 4,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "DMrrT3SudevT",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 578
        },
        "outputId": "5fcd6c36-e006-4187-fa3f-9597d5ba35aa"
      },
      "source": [
        "cnn_model = ConvNet(nin=3, nclass=10)\n",
        "print(cnn_model.vars())"
      ],
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "(ConvNet).conv_block1(Sequential)[0](Conv2D).w                      432 (3, 3, 3, 16)\n",
            "(ConvNet).conv_block1(Sequential)[1](BatchNorm2D).running_mean       16 (1, 16, 1, 1)\n",
            "(ConvNet).conv_block1(Sequential)[1](BatchNorm2D).running_var        16 (1, 16, 1, 1)\n",
            "(ConvNet).conv_block1(Sequential)[1](BatchNorm2D).beta               16 (1, 16, 1, 1)\n",
            "(ConvNet).conv_block1(Sequential)[1](BatchNorm2D).gamma              16 (1, 16, 1, 1)\n",
            "(ConvNet).conv_block1(Sequential)[3](Conv2D).w                     2304 (3, 3, 16, 16)\n",
            "(ConvNet).conv_block1(Sequential)[4](BatchNorm2D).running_mean       16 (1, 16, 1, 1)\n",
            "(ConvNet).conv_block1(Sequential)[4](BatchNorm2D).running_var        16 (1, 16, 1, 1)\n",
            "(ConvNet).conv_block1(Sequential)[4](BatchNorm2D).beta               16 (1, 16, 1, 1)\n",
            "(ConvNet).conv_block1(Sequential)[4](BatchNorm2D).gamma              16 (1, 16, 1, 1)\n",
            "(ConvNet).conv_block2(Sequential)[0](Conv2D).w                     4608 (3, 3, 16, 32)\n",
            "(ConvNet).conv_block2(Sequential)[1](BatchNorm2D).running_mean       32 (1, 32, 1, 1)\n",
            "(ConvNet).conv_block2(Sequential)[1](BatchNorm2D).running_var        32 (1, 32, 1, 1)\n",
            "(ConvNet).conv_block2(Sequential)[1](BatchNorm2D).beta               32 (1, 32, 1, 1)\n",
            "(ConvNet).conv_block2(Sequential)[1](BatchNorm2D).gamma              32 (1, 32, 1, 1)\n",
            "(ConvNet).conv_block2(Sequential)[3](Conv2D).w                     9216 (3, 3, 32, 32)\n",
            "(ConvNet).conv_block2(Sequential)[4](BatchNorm2D).running_mean       32 (1, 32, 1, 1)\n",
            "(ConvNet).conv_block2(Sequential)[4](BatchNorm2D).running_var        32 (1, 32, 1, 1)\n",
            "(ConvNet).conv_block2(Sequential)[4](BatchNorm2D).beta               32 (1, 32, 1, 1)\n",
            "(ConvNet).conv_block2(Sequential)[4](BatchNorm2D).gamma              32 (1, 32, 1, 1)\n",
            "(ConvNet).conv_block3(Sequential)[0](Conv2D).w                    18432 (3, 3, 32, 64)\n",
            "(ConvNet).conv_block3(Sequential)[1](BatchNorm2D).running_mean       64 (1, 64, 1, 1)\n",
            "(ConvNet).conv_block3(Sequential)[1](BatchNorm2D).running_var        64 (1, 64, 1, 1)\n",
            "(ConvNet).conv_block3(Sequential)[1](BatchNorm2D).beta               64 (1, 64, 1, 1)\n",
            "(ConvNet).conv_block3(Sequential)[1](BatchNorm2D).gamma              64 (1, 64, 1, 1)\n",
            "(ConvNet).conv_block3(Sequential)[3](Conv2D).w                    36864 (3, 3, 64, 64)\n",
            "(ConvNet).conv_block3(Sequential)[4](BatchNorm2D).running_mean       64 (1, 64, 1, 1)\n",
            "(ConvNet).conv_block3(Sequential)[4](BatchNorm2D).running_var        64 (1, 64, 1, 1)\n",
            "(ConvNet).conv_block3(Sequential)[4](BatchNorm2D).beta               64 (1, 64, 1, 1)\n",
            "(ConvNet).conv_block3(Sequential)[4](BatchNorm2D).gamma              64 (1, 64, 1, 1)\n",
            "(ConvNet).linear(Linear).b                                           10 (10,)\n",
            "(ConvNet).linear(Linear).w                                          640 (64, 10)\n",
            "+Total(32)                                                        73402\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "an_z0dDmflp1",
        "colab_type": "text"
      },
      "source": [
        "## Model Training and Evaluation\n",
        "\n",
        "The ``train_model`` method combines all the parts of defining the loss function, gradient descent, training loop, and evaluation. It takes the ``model`` as a parameter so it can be reused with the two models we defined earlier.\n",
        "\n",
        "Unlike the Logistic Regression tutorial, we use ``objax.functional.loss.cross_entropy_logits_sparse`` because we perform multi-class classification. The optimizer, gradient descent operation, and training loop remain the same.\n",
        "\n",
        "The ``DNNet`` model expects flattened images, whereas ``ConvNet`` expects images in (C, H, W) format. The ``flatten_image`` function prepares images accordingly before passing them to the model.\n",
        "\n",
        "When using the model for inference, we apply ``objax.functional.softmax`` to compute the probability distribution from the model's logits."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "yN7aRrzw6OnZ",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Settings\n",
        "lr = 0.03  # learning rate\n",
        "batch = 128\n",
        "epochs = 100\n",
        "\n",
        "# Train loop\n",
        "\n",
        "def train_model(model):\n",
        "\n",
        "  def predict(model, x):\n",
        "    \"\"\"Return class probabilities (softmax of the model's logits) for inputs x.\"\"\"\n",
        "    return objax.functional.softmax(model(x,  training=False))\n",
        "    \n",
        "  def flatten_image(x):\n",
        "    \"\"\"Flatten the image before passing it to the DNN.\"\"\"\n",
        "    if isinstance(model, DNNet):\n",
        "      return objax.functional.flatten(x)\n",
        "    else:\n",
        "      return x\n",
        "  \n",
        "  opt = objax.optimizer.Momentum(model.vars())\n",
        "\n",
        "  # Cross Entropy Loss\n",
        "  def loss(x, label):\n",
        "    return objax.functional.loss.cross_entropy_logits_sparse(model(x, training=True), label).mean()\n",
        "\n",
        "  gv = objax.GradValues(loss, model.vars())\n",
        "  def train_op(x, label):\n",
        "    g, v = gv(x, label)  # returns gradients, loss\n",
        "    opt(lr, g)\n",
        "    return v\n",
        "\n",
        "  train_op = objax.Jit(train_op, gv.vars() + opt.vars())  \n",
        "  \n",
        "  for epoch in range(epochs):\n",
        "    avg_loss = 0\n",
        "    # randomly shuffle training data\n",
        "    shuffle_idx = np.random.permutation(train.image.shape[0])\n",
        "    for it in range(0, train.image.shape[0], batch):\n",
        "      sel = shuffle_idx[it: it + batch]\n",
        "      avg_loss += float(train_op(flatten_image(train.image[sel]), train.label[sel])[0]) * len(sel)\n",
        "    avg_loss /= it + len(sel)\n",
        "\n",
        "    # Eval\n",
        "    accuracy = 0\n",
        "    for it in range(0, test.image.shape[0], batch):\n",
        "      x, y = test.image[it: it + batch], test.label[it: it + batch]\n",
        "      accuracy += (np.argmax(predict(model, flatten_image(x)), axis=1) == y).sum()  \n",
        "    accuracy /= test.image.shape[0]\n",
        "    print('Epoch %04d  Loss %.2f  Accuracy %.2f' % (epoch + 1, avg_loss, 100 * accuracy))"
      ],
      "execution_count": 6,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lnrHPajlwHHM",
        "colab_type": "text"
      },
      "source": [
        "## Training the DNN Model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "tZfMOlGPwudp",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "outputId": "b4bb1b23-677d-49d8-e4b6-2c5e83304e6d"
      },
      "source": [
        "train_model(dnn_model)"
      ],
      "execution_count": 7,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Epoch 0001  Loss 2.39  Accuracy 56.31\n",
            "Epoch 0002  Loss 1.24  Accuracy 74.19\n",
            "Epoch 0003  Loss 0.75  Accuracy 84.91\n",
            "Epoch 0004  Loss 0.56  Accuracy 83.10\n",
            "Epoch 0005  Loss 0.48  Accuracy 86.42\n",
            "Epoch 0006  Loss 0.43  Accuracy 89.00\n",
            "Epoch 0007  Loss 0.41  Accuracy 89.62\n",
            "Epoch 0008  Loss 0.39  Accuracy 89.82\n",
            "Epoch 0009  Loss 0.37  Accuracy 90.04\n",
            "Epoch 0010  Loss 0.36  Accuracy 89.55\n",
            "Epoch 0011  Loss 0.35  Accuracy 90.53\n",
            "Epoch 0012  Loss 0.35  Accuracy 90.64\n",
            "Epoch 0013  Loss 0.34  Accuracy 90.85\n",
            "Epoch 0014  Loss 0.33  Accuracy 90.87\n",
            "Epoch 0015  Loss 0.33  Accuracy 91.02\n",
            "Epoch 0016  Loss 0.32  Accuracy 91.35\n",
            "Epoch 0017  Loss 0.32  Accuracy 91.35\n",
            "Epoch 0018  Loss 0.31  Accuracy 91.50\n",
            "Epoch 0019  Loss 0.31  Accuracy 91.57\n",
            "Epoch 0020  Loss 0.31  Accuracy 91.75\n",
            "Epoch 0021  Loss 0.30  Accuracy 91.47\n",
            "Epoch 0022  Loss 0.30  Accuracy 88.14\n",
            "Epoch 0023  Loss 0.30  Accuracy 91.82\n",
            "Epoch 0024  Loss 0.30  Accuracy 91.92\n",
            "Epoch 0025  Loss 0.29  Accuracy 92.03\n",
            "Epoch 0026  Loss 0.29  Accuracy 92.04\n",
            "Epoch 0027  Loss 0.29  Accuracy 92.11\n",
            "Epoch 0028  Loss 0.29  Accuracy 92.11\n",
            "Epoch 0029  Loss 0.29  Accuracy 92.18\n",
            "Epoch 0030  Loss 0.28  Accuracy 92.24\n",
            "Epoch 0031  Loss 0.28  Accuracy 92.36\n",
            "Epoch 0032  Loss 0.28  Accuracy 92.17\n",
            "Epoch 0033  Loss 0.28  Accuracy 92.42\n",
            "Epoch 0034  Loss 0.28  Accuracy 92.42\n",
            "Epoch 0035  Loss 0.27  Accuracy 92.47\n",
            "Epoch 0036  Loss 0.27  Accuracy 92.50\n",
            "Epoch 0037  Loss 0.27  Accuracy 92.49\n",
            "Epoch 0038  Loss 0.27  Accuracy 92.58\n",
            "Epoch 0039  Loss 0.26  Accuracy 92.56\n",
            "Epoch 0040  Loss 0.26  Accuracy 92.56\n",
            "Epoch 0041  Loss 0.26  Accuracy 92.77\n",
            "Epoch 0042  Loss 0.26  Accuracy 92.72\n",
            "Epoch 0043  Loss 0.26  Accuracy 92.80\n",
            "Epoch 0044  Loss 0.25  Accuracy 92.85\n",
            "Epoch 0045  Loss 0.25  Accuracy 92.90\n",
            "Epoch 0046  Loss 0.25  Accuracy 92.96\n",
            "Epoch 0047  Loss 0.25  Accuracy 93.00\n",
            "Epoch 0048  Loss 0.25  Accuracy 92.82\n",
            "Epoch 0049  Loss 0.25  Accuracy 93.18\n",
            "Epoch 0050  Loss 0.24  Accuracy 93.09\n",
            "Epoch 0051  Loss 0.24  Accuracy 92.94\n",
            "Epoch 0052  Loss 0.24  Accuracy 93.20\n",
            "Epoch 0053  Loss 0.24  Accuracy 93.26\n",
            "Epoch 0054  Loss 0.23  Accuracy 93.21\n",
            "Epoch 0055  Loss 0.24  Accuracy 93.42\n",
            "Epoch 0056  Loss 0.23  Accuracy 93.35\n",
            "Epoch 0057  Loss 0.23  Accuracy 93.36\n",
            "Epoch 0058  Loss 0.23  Accuracy 93.56\n",
            "Epoch 0059  Loss 0.23  Accuracy 93.54\n",
            "Epoch 0060  Loss 0.22  Accuracy 93.39\n",
            "Epoch 0061  Loss 0.23  Accuracy 93.56\n",
            "Epoch 0062  Loss 0.22  Accuracy 93.74\n",
            "Epoch 0063  Loss 0.22  Accuracy 93.68\n",
            "Epoch 0064  Loss 0.22  Accuracy 93.72\n",
            "Epoch 0065  Loss 0.22  Accuracy 93.76\n",
            "Epoch 0066  Loss 0.22  Accuracy 93.87\n",
            "Epoch 0067  Loss 0.21  Accuracy 93.89\n",
            "Epoch 0068  Loss 0.21  Accuracy 93.96\n",
            "Epoch 0069  Loss 0.21  Accuracy 93.90\n",
            "Epoch 0070  Loss 0.21  Accuracy 93.99\n",
            "Epoch 0071  Loss 0.21  Accuracy 94.02\n",
            "Epoch 0072  Loss 0.21  Accuracy 93.86\n",
            "Epoch 0073  Loss 0.21  Accuracy 94.06\n",
            "Epoch 0074  Loss 0.21  Accuracy 94.14\n",
            "Epoch 0075  Loss 0.20  Accuracy 94.31\n",
            "Epoch 0076  Loss 0.20  Accuracy 94.14\n",
            "Epoch 0077  Loss 0.20  Accuracy 94.15\n",
            "Epoch 0078  Loss 0.20  Accuracy 94.10\n",
            "Epoch 0079  Loss 0.20  Accuracy 94.16\n",
            "Epoch 0080  Loss 0.20  Accuracy 94.28\n",
            "Epoch 0081  Loss 0.20  Accuracy 94.30\n",
            "Epoch 0082  Loss 0.20  Accuracy 94.28\n",
            "Epoch 0083  Loss 0.19  Accuracy 94.37\n",
            "Epoch 0084  Loss 0.19  Accuracy 94.33\n",
            "Epoch 0085  Loss 0.19  Accuracy 94.31\n",
            "Epoch 0086  Loss 0.19  Accuracy 94.25\n",
            "Epoch 0087  Loss 0.19  Accuracy 94.37\n",
            "Epoch 0088  Loss 0.19  Accuracy 94.38\n",
            "Epoch 0089  Loss 0.19  Accuracy 94.35\n",
            "Epoch 0090  Loss 0.19  Accuracy 94.38\n",
            "Epoch 0091  Loss 0.19  Accuracy 94.41\n",
            "Epoch 0092  Loss 0.19  Accuracy 94.46\n",
            "Epoch 0093  Loss 0.19  Accuracy 94.53\n",
            "Epoch 0094  Loss 0.18  Accuracy 94.47\n",
            "Epoch 0095  Loss 0.18  Accuracy 94.54\n",
            "Epoch 0096  Loss 0.18  Accuracy 94.65\n",
            "Epoch 0097  Loss 0.18  Accuracy 94.56\n",
            "Epoch 0098  Loss 0.18  Accuracy 94.60\n",
            "Epoch 0099  Loss 0.18  Accuracy 94.63\n",
            "Epoch 0100  Loss 0.18  Accuracy 94.46\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "G31YQExqwNNc",
        "colab_type": "text"
      },
      "source": [
        "## Training the ConvNet Model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Pz0TmizfSipf",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "outputId": "008b979c-bde4-4fe0-81ff-5962f9fc0ec1"
      },
      "source": [
        "train_model(cnn_model)"
      ],
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Epoch 0001  Loss 0.27  Accuracy 27.08\n",
            "Epoch 0002  Loss 0.05  Accuracy 41.07\n",
            "Epoch 0003  Loss 0.03  Accuracy 67.77\n",
            "Epoch 0004  Loss 0.03  Accuracy 73.31\n",
            "Epoch 0005  Loss 0.02  Accuracy 90.30\n",
            "Epoch 0006  Loss 0.02  Accuracy 93.10\n",
            "Epoch 0007  Loss 0.02  Accuracy 95.98\n",
            "Epoch 0008  Loss 0.01  Accuracy 98.77\n",
            "Epoch 0009  Loss 0.01  Accuracy 96.58\n",
            "Epoch 0010  Loss 0.01  Accuracy 99.12\n",
            "Epoch 0011  Loss 0.01  Accuracy 98.88\n",
            "Epoch 0012  Loss 0.01  Accuracy 98.64\n",
            "Epoch 0013  Loss 0.01  Accuracy 98.66\n",
            "Epoch 0014  Loss 0.00  Accuracy 98.38\n",
            "Epoch 0015  Loss 0.00  Accuracy 99.15\n",
            "Epoch 0016  Loss 0.00  Accuracy 97.50\n",
            "Epoch 0017  Loss 0.00  Accuracy 98.98\n",
            "Epoch 0018  Loss 0.00  Accuracy 98.94\n",
            "Epoch 0019  Loss 0.00  Accuracy 98.56\n",
            "Epoch 0020  Loss 0.00  Accuracy 99.06\n",
            "Epoch 0021  Loss 0.00  Accuracy 99.26\n",
            "Epoch 0022  Loss 0.00  Accuracy 99.30\n",
            "Epoch 0023  Loss 0.00  Accuracy 99.18\n",
            "Epoch 0024  Loss 0.00  Accuracy 99.49\n",
            "Epoch 0025  Loss 0.00  Accuracy 99.34\n",
            "Epoch 0026  Loss 0.00  Accuracy 99.24\n",
            "Epoch 0027  Loss 0.00  Accuracy 99.38\n",
            "Epoch 0028  Loss 0.00  Accuracy 99.43\n",
            "Epoch 0029  Loss 0.00  Accuracy 99.40\n",
            "Epoch 0030  Loss 0.00  Accuracy 99.50\n",
            "Epoch 0031  Loss 0.00  Accuracy 99.44\n",
            "Epoch 0032  Loss 0.00  Accuracy 99.52\n",
            "Epoch 0033  Loss 0.00  Accuracy 99.46\n",
            "Epoch 0034  Loss 0.00  Accuracy 99.39\n",
            "Epoch 0035  Loss 0.00  Accuracy 99.22\n",
            "Epoch 0036  Loss 0.00  Accuracy 99.26\n",
            "Epoch 0037  Loss 0.00  Accuracy 99.47\n",
            "Epoch 0038  Loss 0.00  Accuracy 99.18\n",
            "Epoch 0039  Loss 0.00  Accuracy 99.39\n",
            "Epoch 0040  Loss 0.00  Accuracy 99.44\n",
            "Epoch 0041  Loss 0.00  Accuracy 99.43\n",
            "Epoch 0042  Loss 0.00  Accuracy 99.50\n",
            "Epoch 0043  Loss 0.00  Accuracy 99.50\n",
            "Epoch 0044  Loss 0.00  Accuracy 99.53\n",
            "Epoch 0045  Loss 0.00  Accuracy 99.51\n",
            "Epoch 0046  Loss 0.00  Accuracy 99.49\n",
            "Epoch 0047  Loss 0.00  Accuracy 99.46\n",
            "Epoch 0048  Loss 0.00  Accuracy 99.46\n",
            "Epoch 0049  Loss 0.00  Accuracy 99.35\n",
            "Epoch 0050  Loss 0.00  Accuracy 99.50\n",
            "Epoch 0051  Loss 0.00  Accuracy 99.48\n",
            "Epoch 0052  Loss 0.00  Accuracy 99.48\n",
            "Epoch 0053  Loss 0.00  Accuracy 99.48\n",
            "Epoch 0054  Loss 0.00  Accuracy 99.46\n",
            "Epoch 0055  Loss 0.00  Accuracy 99.48\n",
            "Epoch 0056  Loss 0.00  Accuracy 99.50\n",
            "Epoch 0057  Loss 0.00  Accuracy 99.41\n",
            "Epoch 0058  Loss 0.00  Accuracy 99.49\n",
            "Epoch 0059  Loss 0.00  Accuracy 99.48\n",
            "Epoch 0060  Loss 0.00  Accuracy 99.47\n",
            "Epoch 0061  Loss 0.00  Accuracy 99.52\n",
            "Epoch 0062  Loss 0.00  Accuracy 99.49\n",
            "Epoch 0063  Loss 0.00  Accuracy 99.48\n",
            "Epoch 0064  Loss 0.00  Accuracy 99.51\n",
            "Epoch 0065  Loss 0.00  Accuracy 99.46\n",
            "Epoch 0066  Loss 0.00  Accuracy 99.51\n",
            "Epoch 0067  Loss 0.00  Accuracy 99.49\n",
            "Epoch 0068  Loss 0.00  Accuracy 99.52\n",
            "Epoch 0069  Loss 0.00  Accuracy 99.49\n",
            "Epoch 0070  Loss 0.00  Accuracy 99.51\n",
            "Epoch 0071  Loss 0.00  Accuracy 99.51\n",
            "Epoch 0072  Loss 0.00  Accuracy 99.52\n",
            "Epoch 0073  Loss 0.00  Accuracy 99.43\n",
            "Epoch 0074  Loss 0.00  Accuracy 99.53\n",
            "Epoch 0075  Loss 0.00  Accuracy 99.47\n",
            "Epoch 0076  Loss 0.00  Accuracy 99.51\n",
            "Epoch 0077  Loss 0.00  Accuracy 99.55\n",
            "Epoch 0078  Loss 0.00  Accuracy 99.52\n",
            "Epoch 0079  Loss 0.00  Accuracy 99.52\n",
            "Epoch 0080  Loss 0.00  Accuracy 98.78\n",
            "Epoch 0081  Loss 0.00  Accuracy 99.16\n",
            "Epoch 0082  Loss 0.00  Accuracy 99.40\n",
            "Epoch 0083  Loss 0.00  Accuracy 99.35\n",
            "Epoch 0084  Loss 0.00  Accuracy 99.32\n",
            "Epoch 0085  Loss 0.00  Accuracy 99.49\n",
            "Epoch 0086  Loss 0.00  Accuracy 99.49\n",
            "Epoch 0087  Loss 0.00  Accuracy 99.56\n",
            "Epoch 0088  Loss 0.00  Accuracy 99.48\n",
            "Epoch 0089  Loss 0.00  Accuracy 99.48\n",
            "Epoch 0090  Loss 0.00  Accuracy 99.51\n",
            "Epoch 0091  Loss 0.00  Accuracy 99.45\n",
            "Epoch 0092  Loss 0.00  Accuracy 99.52\n",
            "Epoch 0093  Loss 0.00  Accuracy 99.52\n",
            "Epoch 0094  Loss 0.00  Accuracy 99.51\n",
            "Epoch 0095  Loss 0.00  Accuracy 99.51\n",
            "Epoch 0096  Loss 0.00  Accuracy 99.48\n",
            "Epoch 0097  Loss 0.00  Accuracy 99.51\n",
            "Epoch 0098  Loss 0.00  Accuracy 99.53\n",
            "Epoch 0099  Loss 0.00  Accuracy 99.50\n",
            "Epoch 0100  Loss 0.00  Accuracy 99.53\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Qi6kHw1ycwIq",
        "colab_type": "text"
      },
      "source": [
        "## Training with PyTorch data processing API\n",
        "\n",
        "One of the pain points for ML researchers/practitioners when building a new ML model is data processing. Here, we demonstrate how to use the data processing API of [PyTorch](https://pytorch.org/) to train a model with Objax. Different deep learning libraries come with different data processing APIs; depending on your preference, you can choose one and easily combine it with Objax.\n",
        "\n",
        "As before, we prepare the `MNIST` dataset and apply the same data preprocessing."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Xv1enS4rd4FD",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 386,
          "referenced_widgets": [
            "1b6f39faaa8544c49e9b84ecfd67ca41",
            "7ddf2ce051e14b249a7b8dfdcd2b08c1",
            "7e948458c6f348adb9b5dcc19594896d",
            "692f5b26181b402690cd682f4ef389ee",
            "012265f6cb7e4516a8a6cf5318a30ce2",
            "e0aff2702cdd47f795f2d2a233ba8248",
            "90782a422602433fb529a31828ea711a",
            "769e755461c64b16a37b279cbf02cb5b",
            "dfe3ba30ba8846ad9a533089bfee6987",
            "b8462eea1e2b49df87b160453da3fb8e",
            "07918b13b9e64c219ca8e2f07c1193be",
            "1eb02d1769b0438582cae93bf765327e",
            "d521e62b249d470b863a70ac156341f6",
            "83ef5b03b73c45b3b5bb8c18eabc72de",
            "787f5b89f95c4255aef60fa6f91d3b22",
            "07d4d1e25e2a41ec97196a1a289199c5",
            "2d599bc13c8543da9ec1c82b11ab6a69",
            "464e56783ced4a039c4f5fcc1fd290e1",
            "678e744efdf9417f8d55151b76bbece1",
            "263e5f6844bd4f7c81876d4093755c63",
            "0cc06bcc77df4921b6b8d9007b42ac05",
            "38a4bb4a15e3405a8007e9691f69d825",
            "f7fb8781ff2646a9bd86440877ecc101",
            "6998ba54432c4eb38bab6792887c7ccc",
            "035b97b7a7e74191ac92016c697ba375",
            "726c4cf5f5b64437a6327c48236aeb42",
            "18fa0b70765842b2abf83c1b4793816c",
            "ade571a07230429d9f5ee996b9fd8760",
            "82a261efb72447feae936e85cd4bbc72",
            "75ecd0780a0b435eb91ddb9c11dbf0d5",
            "db052fa05c2d4224a87e2df722895b24",
            "8c682e975d654977bc80c939cd4a40d8"
          ]
        },
        "outputId": "b26995aa-0b33-4f9a-a1e7-dbd57421427c"
      },
      "source": [
        "import torch\n",
        "from torchvision import datasets, transforms\n",
        "\n",
        "transform=transforms.Compose([\n",
        "                              transforms.Pad((2,2,2,2), 0),\n",
        "                              transforms.ToTensor(),\n",
        "                              transforms.Lambda(lambda x: np.concatenate([x] * 3, axis=0)),\n",
        "                              transforms.Lambda(lambda x: x * 2 - 1)\n",
        "                              ])\n",
        "train_dataset = datasets.MNIST(os.environ['HOME'], train=True, download=True, transform=transform)\n",
        "test_dataset = datasets.MNIST(os.environ['HOME'], train=False, download=True, transform=transform)\n",
        "\n",
        "# Define data loader\n",
        "train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch, shuffle=True)\n",
        "test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch)"
      ],
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /root/MNIST/raw/train-images-idx3-ubyte.gz\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "1b6f39faaa8544c49e9b84ecfd67ca41",
              "version_minor": 0,
              "version_major": 2
            },
            "text/plain": [
              "HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "Extracting /root/MNIST/raw/train-images-idx3-ubyte.gz to /root/MNIST/raw\n",
            "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /root/MNIST/raw/train-labels-idx1-ubyte.gz\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "dfe3ba30ba8846ad9a533089bfee6987",
              "version_minor": 0,
              "version_major": 2
            },
            "text/plain": [
              "HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "Extracting /root/MNIST/raw/train-labels-idx1-ubyte.gz to /root/MNIST/raw\n",
            "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /root/MNIST/raw/t10k-images-idx3-ubyte.gz\n",
            "\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "2d599bc13c8543da9ec1c82b11ab6a69",
              "version_minor": 0,
              "version_major": 2
            },
            "text/plain": [
              "HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "Extracting /root/MNIST/raw/t10k-images-idx3-ubyte.gz to /root/MNIST/raw\n",
            "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /root/MNIST/raw/t10k-labels-idx1-ubyte.gz\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "035b97b7a7e74191ac92016c697ba375",
              "version_minor": 0,
              "version_major": 2
            },
            "text/plain": [
              "HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "Extracting /root/MNIST/raw/t10k-labels-idx1-ubyte.gz to /root/MNIST/raw\n",
            "Processing...\n",
            "Done!\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.6/dist-packages/torchvision/datasets/mnist.py:469: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at  /pytorch/torch/csrc/utils/tensor_numpy.cpp:141.)\n",
            "  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)\n"
          ],
          "name": "stderr"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yNG2rcmylpih",
        "colab_type": "text"
      },
      "source": [
        "We replace the data processing pipeline of the train and test loops with `train_loader` and `test_loader`, and that's it!"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "94E64UlHd0zu",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "9edf30c0-781d-4224-b341-325c59a246c6"
      },
      "source": [
        "# Train loop\n",
        "\n",
        "def train_model_with_torch_data_api(model):\n",
        "\n",
        "  def predict(model, x):\n",
        "    \"\"\"\"\"\" \n",
        "    return objax.functional.softmax(model(x,  training=False))\n",
        "    \n",
        "  def flatten_image(x):\n",
        "    \"\"\"Flatten the image before passing it to the DNN.\"\"\"\n",
        "    if isinstance(model, DNNet):\n",
        "      return objax.functional.flatten(x)\n",
        "    else:\n",
        "      return x\n",
        "  \n",
        "  opt = objax.optimizer.Momentum(model.vars())\n",
        "\n",
        "  # Cross Entropy Loss\n",
        "  def loss(x, label):\n",
        "    return objax.functional.loss.cross_entropy_logits_sparse(model(x, training=True), label).mean()\n",
        "\n",
        "  gv = objax.GradValues(loss, model.vars())\n",
        "  def train_op(x, label):\n",
        "    g, v = gv(x, label)  # returns gradients, loss\n",
        "    opt(lr, g)\n",
        "    return v\n",
        "\n",
        "  train_op = objax.Jit(train_op, gv.vars() + opt.vars())  \n",
        "  \n",
        "  for epoch in range(epochs):\n",
        "    avg_loss = 0\n",
        "    tot_data = 0\n",
        "    for _, (img, label) in enumerate(train_loader):\n",
        "      avg_loss += float(train_op(flatten_image(img.numpy()), label.numpy())[0]) * len(img)\n",
        "      tot_data += len(img)\n",
        "    avg_loss /= tot_data\n",
        "\n",
        "    # Eval\n",
        "    accuracy = 0\n",
        "    tot_data = 0\n",
        "    for _, (img, label) in enumerate(test_loader):\n",
        "      accuracy += (np.argmax(predict(model, flatten_image(img.numpy())), axis=1) == label.numpy()).sum()\n",
        "      tot_data += len(img)\n",
        "    accuracy /= tot_data\n",
        "    print('Epoch %04d  Loss %.2f  Accuracy %.2f' % (epoch + 1, avg_loss, 100 * accuracy))"
      ],
      "execution_count": 10,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eaUy1lXokCL0",
        "colab_type": "text"
      },
      "source": [
        "## Training the DNN Model with PyTorch data API"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "fWzkHDNJj_QH",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "outputId": "59c70897-d048-4ef9-de55-da559cdd77bf"
      },
      "source": [
        "dnn_layer_sizes = 3072, 128, 10\n",
        "dnn_model = DNNet(dnn_layer_sizes, objax.functional.leaky_relu)\n",
        "train_model_with_torch_data_api(dnn_model)"
      ],
      "execution_count": 11,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Epoch 0001  Loss 2.57  Accuracy 34.35\n",
            "Epoch 0002  Loss 1.93  Accuracy 58.51\n",
            "Epoch 0003  Loss 1.32  Accuracy 68.46\n",
            "Epoch 0004  Loss 0.83  Accuracy 80.95\n",
            "Epoch 0005  Loss 0.62  Accuracy 84.74\n",
            "Epoch 0006  Loss 0.53  Accuracy 86.53\n",
            "Epoch 0007  Loss 0.48  Accuracy 84.18\n",
            "Epoch 0008  Loss 0.45  Accuracy 88.42\n",
            "Epoch 0009  Loss 0.42  Accuracy 87.34\n",
            "Epoch 0010  Loss 0.40  Accuracy 89.29\n",
            "Epoch 0011  Loss 0.39  Accuracy 89.31\n",
            "Epoch 0012  Loss 0.38  Accuracy 89.86\n",
            "Epoch 0013  Loss 0.37  Accuracy 89.91\n",
            "Epoch 0014  Loss 0.36  Accuracy 86.94\n",
            "Epoch 0015  Loss 0.36  Accuracy 89.89\n",
            "Epoch 0016  Loss 0.35  Accuracy 90.12\n",
            "Epoch 0017  Loss 0.34  Accuracy 90.40\n",
            "Epoch 0018  Loss 0.34  Accuracy 90.31\n",
            "Epoch 0019  Loss 0.34  Accuracy 90.79\n",
            "Epoch 0020  Loss 0.33  Accuracy 90.71\n",
            "Epoch 0021  Loss 0.33  Accuracy 90.70\n",
            "Epoch 0022  Loss 0.33  Accuracy 90.69\n",
            "Epoch 0023  Loss 0.33  Accuracy 90.91\n",
            "Epoch 0024  Loss 0.32  Accuracy 90.92\n",
            "Epoch 0025  Loss 0.32  Accuracy 91.06\n",
            "Epoch 0026  Loss 0.32  Accuracy 91.19\n",
            "Epoch 0027  Loss 0.32  Accuracy 91.31\n",
            "Epoch 0028  Loss 0.31  Accuracy 91.31\n",
            "Epoch 0029  Loss 0.31  Accuracy 91.20\n",
            "Epoch 0030  Loss 0.31  Accuracy 91.31\n",
            "Epoch 0031  Loss 0.31  Accuracy 91.36\n",
            "Epoch 0032  Loss 0.31  Accuracy 91.42\n",
            "Epoch 0033  Loss 0.30  Accuracy 91.27\n",
            "Epoch 0034  Loss 0.31  Accuracy 91.47\n",
            "Epoch 0035  Loss 0.30  Accuracy 91.57\n",
            "Epoch 0036  Loss 0.30  Accuracy 91.44\n",
            "Epoch 0037  Loss 0.30  Accuracy 91.55\n",
            "Epoch 0038  Loss 0.30  Accuracy 91.56\n",
            "Epoch 0039  Loss 0.29  Accuracy 91.75\n",
            "Epoch 0040  Loss 0.29  Accuracy 91.69\n",
            "Epoch 0041  Loss 0.29  Accuracy 91.60\n",
            "Epoch 0042  Loss 0.29  Accuracy 91.77\n",
            "Epoch 0043  Loss 0.29  Accuracy 91.76\n",
            "Epoch 0044  Loss 0.29  Accuracy 91.84\n",
            "Epoch 0045  Loss 0.28  Accuracy 92.05\n",
            "Epoch 0046  Loss 0.28  Accuracy 91.78\n",
            "Epoch 0047  Loss 0.28  Accuracy 92.01\n",
            "Epoch 0048  Loss 0.28  Accuracy 91.95\n",
            "Epoch 0049  Loss 0.28  Accuracy 90.11\n",
            "Epoch 0050  Loss 0.28  Accuracy 92.14\n",
            "Epoch 0051  Loss 0.28  Accuracy 92.03\n",
            "Epoch 0052  Loss 0.27  Accuracy 92.29\n",
            "Epoch 0053  Loss 0.27  Accuracy 92.17\n",
            "Epoch 0054  Loss 0.27  Accuracy 92.12\n",
            "Epoch 0055  Loss 0.27  Accuracy 92.34\n",
            "Epoch 0056  Loss 0.27  Accuracy 92.32\n",
            "Epoch 0057  Loss 0.27  Accuracy 92.47\n",
            "Epoch 0058  Loss 0.27  Accuracy 92.38\n",
            "Epoch 0059  Loss 0.27  Accuracy 92.39\n",
            "Epoch 0060  Loss 0.26  Accuracy 92.51\n",
            "Epoch 0061  Loss 0.27  Accuracy 92.50\n",
            "Epoch 0062  Loss 0.26  Accuracy 92.46\n",
            "Epoch 0063  Loss 0.26  Accuracy 92.65\n",
            "Epoch 0064  Loss 0.26  Accuracy 92.57\n",
            "Epoch 0065  Loss 0.26  Accuracy 92.63\n",
            "Epoch 0066  Loss 0.26  Accuracy 92.75\n",
            "Epoch 0067  Loss 0.26  Accuracy 92.57\n",
            "Epoch 0068  Loss 0.26  Accuracy 92.88\n",
            "Epoch 0069  Loss 0.25  Accuracy 92.53\n",
            "Epoch 0070  Loss 0.25  Accuracy 92.80\n",
            "Epoch 0071  Loss 0.25  Accuracy 92.71\n",
            "Epoch 0072  Loss 0.25  Accuracy 92.75\n",
            "Epoch 0073  Loss 0.25  Accuracy 92.84\n",
            "Epoch 0074  Loss 0.25  Accuracy 92.71\n",
            "Epoch 0075  Loss 0.25  Accuracy 92.95\n",
            "Epoch 0076  Loss 0.25  Accuracy 92.82\n",
            "Epoch 0077  Loss 0.25  Accuracy 92.90\n",
            "Epoch 0078  Loss 0.25  Accuracy 92.87\n",
            "Epoch 0079  Loss 0.25  Accuracy 89.55\n",
            "Epoch 0080  Loss 0.25  Accuracy 92.86\n",
            "Epoch 0081  Loss 0.24  Accuracy 92.99\n",
            "Epoch 0082  Loss 0.24  Accuracy 93.03\n",
            "Epoch 0083  Loss 0.24  Accuracy 93.03\n",
            "Epoch 0084  Loss 0.24  Accuracy 93.01\n",
            "Epoch 0085  Loss 0.24  Accuracy 93.13\n",
            "Epoch 0086  Loss 0.24  Accuracy 93.17\n",
            "Epoch 0087  Loss 0.24  Accuracy 92.87\n",
            "Epoch 0088  Loss 0.24  Accuracy 92.93\n",
            "Epoch 0089  Loss 0.24  Accuracy 93.16\n",
            "Epoch 0090  Loss 0.24  Accuracy 93.38\n",
            "Epoch 0091  Loss 0.24  Accuracy 92.98\n",
            "Epoch 0092  Loss 0.24  Accuracy 93.30\n",
            "Epoch 0093  Loss 0.23  Accuracy 93.09\n",
            "Epoch 0094  Loss 0.23  Accuracy 93.19\n",
            "Epoch 0095  Loss 0.23  Accuracy 93.25\n",
            "Epoch 0096  Loss 0.23  Accuracy 93.22\n",
            "Epoch 0097  Loss 0.23  Accuracy 93.28\n",
            "Epoch 0098  Loss 0.23  Accuracy 93.39\n",
            "Epoch 0099  Loss 0.23  Accuracy 93.25\n",
            "Epoch 0100  Loss 0.23  Accuracy 93.30\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "o7MkDt3LkZio",
        "colab_type": "text"
      },
      "source": [
        "## Training the ConvNet Model with PyTorch data API"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "uWADCELWkF_x",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "outputId": "02d93686-d97d-4265-f4ce-2cfd0b3b4249"
      },
      "source": [
        "cnn_model = ConvNet(nin=3, nclass=10)\n",
        "print(cnn_model.vars())\n",
        "train_model_with_torch_data_api(cnn_model)"
      ],
      "execution_count": 12,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "(ConvNet).conv_block1(Sequential)[0](Conv2D).w                      432 (3, 3, 3, 16)\n",
            "(ConvNet).conv_block1(Sequential)[1](BatchNorm2D).running_mean       16 (1, 16, 1, 1)\n",
            "(ConvNet).conv_block1(Sequential)[1](BatchNorm2D).running_var        16 (1, 16, 1, 1)\n",
            "(ConvNet).conv_block1(Sequential)[1](BatchNorm2D).beta               16 (1, 16, 1, 1)\n",
            "(ConvNet).conv_block1(Sequential)[1](BatchNorm2D).gamma              16 (1, 16, 1, 1)\n",
            "(ConvNet).conv_block1(Sequential)[3](Conv2D).w                     2304 (3, 3, 16, 16)\n",
            "(ConvNet).conv_block1(Sequential)[4](BatchNorm2D).running_mean       16 (1, 16, 1, 1)\n",
            "(ConvNet).conv_block1(Sequential)[4](BatchNorm2D).running_var        16 (1, 16, 1, 1)\n",
            "(ConvNet).conv_block1(Sequential)[4](BatchNorm2D).beta               16 (1, 16, 1, 1)\n",
            "(ConvNet).conv_block1(Sequential)[4](BatchNorm2D).gamma              16 (1, 16, 1, 1)\n",
            "(ConvNet).conv_block2(Sequential)[0](Conv2D).w                     4608 (3, 3, 16, 32)\n",
            "(ConvNet).conv_block2(Sequential)[1](BatchNorm2D).running_mean       32 (1, 32, 1, 1)\n",
            "(ConvNet).conv_block2(Sequential)[1](BatchNorm2D).running_var        32 (1, 32, 1, 1)\n",
            "(ConvNet).conv_block2(Sequential)[1](BatchNorm2D).beta               32 (1, 32, 1, 1)\n",
            "(ConvNet).conv_block2(Sequential)[1](BatchNorm2D).gamma              32 (1, 32, 1, 1)\n",
            "(ConvNet).conv_block2(Sequential)[3](Conv2D).w                     9216 (3, 3, 32, 32)\n",
            "(ConvNet).conv_block2(Sequential)[4](BatchNorm2D).running_mean       32 (1, 32, 1, 1)\n",
            "(ConvNet).conv_block2(Sequential)[4](BatchNorm2D).running_var        32 (1, 32, 1, 1)\n",
            "(ConvNet).conv_block2(Sequential)[4](BatchNorm2D).beta               32 (1, 32, 1, 1)\n",
            "(ConvNet).conv_block2(Sequential)[4](BatchNorm2D).gamma              32 (1, 32, 1, 1)\n",
            "(ConvNet).conv_block3(Sequential)[0](Conv2D).w                    18432 (3, 3, 32, 64)\n",
            "(ConvNet).conv_block3(Sequential)[1](BatchNorm2D).running_mean       64 (1, 64, 1, 1)\n",
            "(ConvNet).conv_block3(Sequential)[1](BatchNorm2D).running_var        64 (1, 64, 1, 1)\n",
            "(ConvNet).conv_block3(Sequential)[1](BatchNorm2D).beta               64 (1, 64, 1, 1)\n",
            "(ConvNet).conv_block3(Sequential)[1](BatchNorm2D).gamma              64 (1, 64, 1, 1)\n",
            "(ConvNet).conv_block3(Sequential)[3](Conv2D).w                    36864 (3, 3, 64, 64)\n",
            "(ConvNet).conv_block3(Sequential)[4](BatchNorm2D).running_mean       64 (1, 64, 1, 1)\n",
            "(ConvNet).conv_block3(Sequential)[4](BatchNorm2D).running_var        64 (1, 64, 1, 1)\n",
            "(ConvNet).conv_block3(Sequential)[4](BatchNorm2D).beta               64 (1, 64, 1, 1)\n",
            "(ConvNet).conv_block3(Sequential)[4](BatchNorm2D).gamma              64 (1, 64, 1, 1)\n",
            "(ConvNet).linear(Linear).b                                           10 (10,)\n",
            "(ConvNet).linear(Linear).w                                          640 (64, 10)\n",
            "+Total(32)                                                        73402\n",
            "Epoch 0001  Loss 0.26  Accuracy 24.18\n",
            "Epoch 0002  Loss 0.05  Accuracy 37.53\n",
            "Epoch 0003  Loss 0.03  Accuracy 42.17\n",
            "Epoch 0004  Loss 0.03  Accuracy 73.50\n",
            "Epoch 0005  Loss 0.02  Accuracy 80.33\n",
            "Epoch 0006  Loss 0.02  Accuracy 83.28\n",
            "Epoch 0007  Loss 0.02  Accuracy 90.87\n",
            "Epoch 0008  Loss 0.01  Accuracy 98.77\n",
            "Epoch 0009  Loss 0.01  Accuracy 98.42\n",
            "Epoch 0010  Loss 0.01  Accuracy 98.16\n",
            "Epoch 0011  Loss 0.01  Accuracy 98.74\n",
            "Epoch 0012  Loss 0.01  Accuracy 95.05\n",
            "Epoch 0013  Loss 0.01  Accuracy 98.89\n",
            "Epoch 0014  Loss 0.00  Accuracy 98.70\n",
            "Epoch 0015  Loss 0.00  Accuracy 99.01\n",
            "Epoch 0016  Loss 0.00  Accuracy 98.97\n",
            "Epoch 0017  Loss 0.00  Accuracy 98.79\n",
            "Epoch 0018  Loss 0.00  Accuracy 98.37\n",
            "Epoch 0019  Loss 0.00  Accuracy 99.19\n",
            "Epoch 0020  Loss 0.00  Accuracy 99.22\n",
            "Epoch 0021  Loss 0.00  Accuracy 98.43\n",
            "Epoch 0022  Loss 0.00  Accuracy 99.02\n",
            "Epoch 0023  Loss 0.00  Accuracy 99.38\n",
            "Epoch 0024  Loss 0.00  Accuracy 99.42\n",
            "Epoch 0025  Loss 0.00  Accuracy 99.45\n",
            "Epoch 0026  Loss 0.00  Accuracy 99.35\n",
            "Epoch 0027  Loss 0.00  Accuracy 99.42\n",
            "Epoch 0028  Loss 0.00  Accuracy 99.42\n",
            "Epoch 0029  Loss 0.00  Accuracy 99.14\n",
            "Epoch 0030  Loss 0.00  Accuracy 99.33\n",
            "Epoch 0031  Loss 0.00  Accuracy 99.36\n",
            "Epoch 0032  Loss 0.00  Accuracy 99.18\n",
            "Epoch 0033  Loss 0.00  Accuracy 99.43\n",
            "Epoch 0034  Loss 0.00  Accuracy 99.47\n",
            "Epoch 0035  Loss 0.00  Accuracy 99.49\n",
            "Epoch 0036  Loss 0.00  Accuracy 99.53\n",
            "Epoch 0037  Loss 0.00  Accuracy 99.38\n",
            "Epoch 0038  Loss 0.00  Accuracy 99.39\n",
            "Epoch 0039  Loss 0.00  Accuracy 99.49\n",
            "Epoch 0040  Loss 0.00  Accuracy 99.49\n",
            "Epoch 0041  Loss 0.00  Accuracy 99.47\n",
            "Epoch 0042  Loss 0.00  Accuracy 99.54\n",
            "Epoch 0043  Loss 0.00  Accuracy 99.35\n",
            "Epoch 0044  Loss 0.00  Accuracy 99.45\n",
            "Epoch 0045  Loss 0.00  Accuracy 99.47\n",
            "Epoch 0046  Loss 0.00  Accuracy 99.53\n",
            "Epoch 0047  Loss 0.00  Accuracy 99.50\n",
            "Epoch 0048  Loss 0.00  Accuracy 99.52\n",
            "Epoch 0049  Loss 0.00  Accuracy 99.51\n",
            "Epoch 0050  Loss 0.00  Accuracy 99.49\n",
            "Epoch 0051  Loss 0.00  Accuracy 99.45\n",
            "Epoch 0052  Loss 0.00  Accuracy 99.48\n",
            "Epoch 0053  Loss 0.00  Accuracy 99.50\n",
            "Epoch 0054  Loss 0.00  Accuracy 99.46\n",
            "Epoch 0055  Loss 0.00  Accuracy 99.50\n",
            "Epoch 0056  Loss 0.00  Accuracy 99.48\n",
            "Epoch 0057  Loss 0.00  Accuracy 99.46\n",
            "Epoch 0058  Loss 0.00  Accuracy 99.44\n",
            "Epoch 0059  Loss 0.00  Accuracy 99.46\n",
            "Epoch 0060  Loss 0.00  Accuracy 99.26\n",
            "Epoch 0061  Loss 0.00  Accuracy 93.99\n",
            "Epoch 0062  Loss 0.00  Accuracy 97.80\n",
            "Epoch 0063  Loss 0.00  Accuracy 80.26\n",
            "Epoch 0064  Loss 0.00  Accuracy 99.20\n",
            "Epoch 0065  Loss 0.00  Accuracy 99.38\n",
            "Epoch 0066  Loss 0.00  Accuracy 99.44\n",
            "Epoch 0067  Loss 0.00  Accuracy 99.51\n",
            "Epoch 0068  Loss 0.00  Accuracy 99.45\n",
            "Epoch 0069  Loss 0.00  Accuracy 99.42\n",
            "Epoch 0070  Loss 0.00  Accuracy 99.50\n",
            "Epoch 0071  Loss 0.00  Accuracy 99.52\n",
            "Epoch 0072  Loss 0.00  Accuracy 99.44\n",
            "Epoch 0073  Loss 0.00  Accuracy 99.41\n",
            "Epoch 0074  Loss 0.00  Accuracy 99.46\n",
            "Epoch 0075  Loss 0.00  Accuracy 99.42\n",
            "Epoch 0076  Loss 0.00  Accuracy 99.49\n",
            "Epoch 0077  Loss 0.00  Accuracy 99.50\n",
            "Epoch 0078  Loss 0.00  Accuracy 99.56\n",
            "Epoch 0079  Loss 0.00  Accuracy 99.52\n",
            "Epoch 0080  Loss 0.00  Accuracy 99.42\n",
            "Epoch 0081  Loss 0.00  Accuracy 99.49\n",
            "Epoch 0082  Loss 0.00  Accuracy 99.48\n",
            "Epoch 0083  Loss 0.00  Accuracy 99.44\n",
            "Epoch 0084  Loss 0.00  Accuracy 99.49\n",
            "Epoch 0085  Loss 0.00  Accuracy 99.53\n",
            "Epoch 0086  Loss 0.00  Accuracy 99.52\n",
            "Epoch 0087  Loss 0.00  Accuracy 99.52\n",
            "Epoch 0088  Loss 0.00  Accuracy 99.50\n",
            "Epoch 0089  Loss 0.00  Accuracy 99.51\n",
            "Epoch 0090  Loss 0.00  Accuracy 99.50\n",
            "Epoch 0091  Loss 0.00  Accuracy 99.49\n",
            "Epoch 0092  Loss 0.00  Accuracy 99.52\n",
            "Epoch 0093  Loss 0.00  Accuracy 99.50\n",
            "Epoch 0094  Loss 0.00  Accuracy 99.50\n",
            "Epoch 0095  Loss 0.00  Accuracy 99.55\n",
            "Epoch 0096  Loss 0.00  Accuracy 99.48\n",
            "Epoch 0097  Loss 0.00  Accuracy 99.51\n",
            "Epoch 0098  Loss 0.00  Accuracy 99.52\n",
            "Epoch 0099  Loss 0.00  Accuracy 99.49\n",
            "Epoch 0100  Loss 0.00  Accuracy 99.52\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "r9WgeoI3wo-o",
        "colab_type": "text"
      },
      "source": [
        "## What's Next\n",
        "\n",
        "We have learned how to use existing models and define new models to classify MNIST. Next, you can read one or more of the in-depth topics or browse through Objax's APIs."
      ]
    }
  ]
}